# -*- coding:utf-8 -*-

# 定义节点类,FP树由节点构成
# name 节点名称
# count 权值
# parent 父节点
# children 直接子节点集合
# link 与之平行的节点


class Node:
    # 初始化节点
    def __init__(self, name, count, parent):
        self.name = name
        self.count = count
        self.parent = parent
        self.children = {}
        self.link = None

    # 当前节点权值加1
    def increase(self, count):
        self.count += count

    # 显示树的结构
    def display(self, ind=1):
        print("  "*ind + self.name + ":" + str(self.count))
        for c in self.children.values():
            c.display(ind + 1)


'''
转换2维列表
转换为key为不可变集合对象,value初始值为1的字典
例如frozenset({"r", "z", "h", "j", "p"})
'''


def ready_data(data):
    tmp_dict = {}
    for i in data:
        tmp_dict[frozenset(i)] = 1
    return tmp_dict


'''
创建FP树，返回项头表
data_set数据结构为字典,key为frozenset,value初始值为1
min_support最小支持度(元素出现的次数)默认为3
'''


def create_tree(data_set, min_support=3):
    tmp_dict = {}
    table_dict = {}
    # 获取所有出现的元素,并且获得出现的次数
    for i in data_set:
        for j in i:
            tmp_dict[j] = tmp_dict.get(j, 0) + data_set[i]
    # 再进行一次循环,把小于最小支持度的元素过滤掉,得到频繁项集及其支持度的字典
    for k in tmp_dict:
        if tmp_dict[k] >= min_support:
            table_dict[k] = tmp_dict[k]
    # 只获取key，得到频繁项集
    frequent_set = set(table_dict.keys())
    if len(frequent_set) == 0:  # 如果都小于最小支持度,那么返回空
        return None, None
    for k in table_dict:  # 为项头表初始化数据
        table_dict[k] = [table_dict[k], None]
    # 初始化树,建立根节点
    tree = Node("NULL", 1, None)
    for key, value in data_set.items():  # 循环获取data_set中的(key, value)键值对
        localD = {}
        for item in key:  # 对key中的元素进行循环
            if item in frequent_set:  # 如果key中的元素是频繁项集
                localD[item] = table_dict[item][0]  # 把每个交易中的当前满足支持度的商品名称作为key,在所有交易中出现的次数作为value
        if len(localD) > 0:
            # 把满足支持度的商品按照支持度进行降序排列,为了生成FP树做准备
            ordered_items = [v[0] for v in sorted(localD.items(), key=lambda p: p[1], reverse=True)]
            # 把当前交易中频繁出现的商品加载到tree中,形成树状结构
            update_tree(ordered_items, tree, table_dict, value)
    return tree, table_dict


'''
添加各个节点,最终形成FP树
items每次要添加的节点集合
tree当前节点的父节点
header_table项头表
count新的节点权值都为1
'''


def update_tree(items, tree, header_table, count):
    if items[0] in tree.children:  # 如果第一个元素出现过
        tree.children[items[0]].increase(count)  # 权值加1
    else:
        # 没出现过则新增这个子节点
        tree.children[items[0]] = Node(items[0], count, tree)
        if header_table[items[0]][1] is None:
            header_table[items[0]][1] = tree.children[items[0]]
        else:
            update_header(header_table[items[0]][1], tree.children[items[0]])
    if len(items) > 1:  # 如果还有要处理的元素,例如：items=[a,b,c],处理完a,再处理b
        # 依次添加到FP树中
        update_tree(items[1::], tree.children[items[0]], header_table, count)


def update_header(nodeToTest, targetNode):
    while nodeToTest.link != None:
        nodeToTest = nodeToTest.link
    nodeToTest.link = targetNode


'''
获得指定节点到根节点的所有路径,路径上的节点名依次添加到列表中
'''


def backtrack(leaf_node, path_list: list):
    if leaf_node.parent is not None:  # 如果当前节点有父节点
        path_list.append(leaf_node.name)
        backtrack(leaf_node.parent, path_list)  # 调用自己,直到最高一级父节点(根节点除外,因为根节点的parent为None)


'''
寻找给定元素结尾的所有路径
'''


def find_prefix_path(tree_node):
    paths = {}
    while tree_node is not None:
        prefix_path = []
        backtrack(tree_node, prefix_path)
        if len(prefix_path) > 1:
            paths[frozenset(prefix_path[1:])] = tree_node.count
        tree_node = tree_node.link
    return paths


def mineTree(in_tree, header_table, min_support, prefix, frequent_list):
    bigL = [v[0] for v in sorted(header_table.items(), key=lambda p:p[0])]
    for name in bigL:
        new_frequent_set = prefix.copy()
        new_frequent_set.add(name)
        frequent_list.append(new_frequent_set)
        condition = find_prefix_path(header_table[name][1])
        con_tree, new_header = create_tree(condition, min_support)
        if new_header is not None:
            mineTree(con_tree, new_header, min_support, new_frequent_set, frequent_list)


'''
FP_Growth算法主函数
data_set数据类型为2维list
'''


def FP_Growth(data_set, min_support=3):
    init_set = ready_data(data_set)
    myFPtree, myHeaderTab = create_tree(init_set, min_support)
    freqItems = []
    mineTree(myFPtree, myHeaderTab, min_support, set([]), freqItems)
    return freqItems